library(MASS)
library(monomvn)
library(ggplot2)
library(ggthemes)
library(tidyverse)
library(dplyr)
library(lfe)

# Draw one i.i.d. (row-level) bootstrap replicate of an estimator.
#
# Args:
#   x:          Ignored. Present so the function can be mapped over a
#               replicate index, e.g. lapply(seq_len(nboot), boot_iid, ...).
#   est_fun:    Function called as est_fun(pmax, pmin, y_est, nsims = nsims).
#   pmax, pmin: Matrix-like objects whose rows are resampled jointly with y_est.
#   y_est:      Vector with one entry per row of pmax / pmin.
#   nsims:      Forwarded unchanged to est_fun.
#
# Returns: whatever est_fun returns on the resampled rows.
boot_iid <- function(x, est_fun, pmax, pmin, y_est, nsims){
  n_obs <- nrow(pmax)
  samp <- sample.int(n_obs, n_obs, replace = TRUE)
  # drop = FALSE keeps the resampled inputs two-dimensional even when there is
  # a single row or column, so est_fun can still index them by column.
  est_fun(pmax[samp, , drop = FALSE], pmin[samp, , drop = FALSE], y_est[samp],
          nsims = nsims)
}



# Per-observation contributions to a "top-group mean minus bottom-group mean"
# contrast.
#
# Args:
#   in_max: Logical vector flagging observations in the top group.
#   in_min: Logical vector flagging observations in the bottom group.
#   y_est:  Numeric outcome vector indexed by the two flags.
#
# Returns: numeric vector of length(y_est). Top-group entries hold
#   y / sum(in_max), bottom-group entries hold -y / sum(in_min), and all other
#   entries are 0. When the groups do not overlap, sum() of the result equals
#   mean(y_est[in_max]) - mean(y_est[in_min]).
get_resids <- function(in_max, in_min, y_est){
  contrib <- rep(0, length(y_est))
  contrib[in_max] <- (1 / sum(in_max)) * y_est[in_max]
  # Assigned second, so an observation flagged in both groups keeps its
  # bottom-group (negative) weight.
  contrib[in_min] <- (-1 / sum(in_min)) * y_est[in_min]
  contrib
}



# Average, over simulated draws, of the top-minus-bottom outcome contrast.
#
# Args:
#   pmax, pmin: Matrices with one row per observation and one column per
#               draw; an entry > 0 marks membership in the top / bottom group
#               for that draw.
#   y_est:      Outcome vector, one entry per row of pmax / pmin.
#   nsims:      Number of columns (draws) to use.
#
# Returns: a single number, the mean over the nsims draws of the per-draw sum
#   of get_resids() contributions.
get_treat_mcseq_from_v_ls <- function(pmax, pmin, y_est, nsims){
  contribs <- matrix(NA_real_, nrow = length(y_est), ncol = nsims)
  # seq_len() so that nsims = 0 skips the loop instead of iterating over c(1, 0).
  for (j in seq_len(nsims)) {
    contribs[, j] <- get_resids(in_max = pmax[, j] > 0,
                                in_min = pmin[, j] > 0,
                                y_est = y_est)
  }
  # Row means average each observation's contribution across draws; summing
  # those gives the draw-averaged contrast.
  sum(rowMeans(contribs))
}



# Split-sample estimate of the top-q versus bottom-q outcome contrast, with a
# bootstrap variance.
#
# Coefficients are fit on the "split" sample, coefficient vectors are simulated
# from their estimated sampling distribution, and each simulated vector ranks
# the "est" sample into top and bottom groups whose outcomes are contrasted.
#
# Args:
#   t_split, y_split: Treatments and outcome used to fit the linear model.
#   t_est, y_est:     Treatments and outcome used to evaluate the contrast.
#   nsims:            Number of simulated coefficient vectors.
#   nboot:            Number of bootstrap replicates for the variance.
#   q:                Tail fraction defining the bottom (<= q quantile) and
#                     top (>= 1 - q quantile) groups.
#
# Returns: list(est = point estimate, v = variance of the bootstrap estimates).
get_mcseq_ls_treat <- function(t_split, t_est, y_split, y_est, nsims, nboot,q){
  split_mod <- felm(y_split ~ t_split)
  # Drop the first coefficient (the intercept of the fit) and its row/column.
  beta <- coef(split_mod)[-1]
  .Sigma <- vcov(split_mod)[-1, -1]

  betas <- mvrnorm(n = nsims, mu = beta, Sigma = .Sigma, tol = 1e-6,
                   empirical = FALSE, EISPACK = FALSE)
  # One column of predicted indices per simulated coefficient vector.
  ypreds <- t_est %*% t(betas)

  # Membership indicators scaled by 1 / nsims; downstream code only tests > 0.
  # Local names avoid shadowing base::pmin / base::pmax.
  in_bottom <- apply(ypreds, 2, function(x){x <= quantile(x, q)}) / nsims
  in_top <- apply(ypreds, 2, function(x){x >= quantile(x, 1 - q)}) / nsims

  est <- get_treat_mcseq_from_v_ls(in_top, in_bottom, y_est, nsims = nsims)
  # seq_len() so that nboot = 0 yields no replicates rather than two.
  boots <- vapply(seq_len(nboot), boot_iid, numeric(1),
                  est_fun = get_treat_mcseq_from_v_ls,
                  pmax = in_top, pmin = in_bottom, y_est = y_est, nsims = nsims)

  list(est = est, v = var(boots))
}



# Cross-fitted estimate: randomly assign rows to folds, and for each fold fit
# on the other folds and evaluate on the held-out fold.
#
# Args:
#   treat_types: Matrix of treatments, one row per observation.
#   y:           Outcome, one entry per row of treat_types.
#   treatment:   Unused; kept for backward compatibility with existing calls.
#   nfold:       Number of folds.
#   nboot:       Bootstrap replicates, forwarded to est_fun.
#   est_fun:     Function(t_split, t_est, y_split, y_est, nsims, q, nboot)
#                returning a list with elements est and v.
#   q, nsims:    Forwarded to est_fun.
#
# Returns: list with est (mean of fold estimates), se (square root of the mean
#   fold variance), and the estimate / variance of the first and second folds
#   (the second-fold entries are NA when nfold < 2).
est_mcseq <- function(treat_types, y, treatment, nfold=2, nboot=100, est_fun, q, nsims){
  folds <- sample.int(nfold, nrow(treat_types), replace = TRUE)

  ests <- rep(NA_real_, nfold)
  vars <- rep(NA_real_, nfold)
  for (fold in seq_len(nfold)) {
    held_out <- folds == fold
    # drop = FALSE keeps the treatment subsets as matrices even if a subset
    # has a single row or treat_types has a single column.
    mod <- est_fun(t_split = treat_types[!held_out, , drop = FALSE],
                   t_est = treat_types[held_out, , drop = FALSE],
                   y_split = y[!held_out], y_est = y[held_out],
                   nsims = nsims, q = q, nboot = nboot)

    ests[fold] <- mod$est
    vars[fold] <- mod$v
  }

  has_second <- nfold >= 2
  # NOTE(review): se is sqrt(mean(vars)), not the standard error of the
  # average of independent fold estimates -- confirm this is intended.
  list(est = mean(ests), se = sqrt(mean(vars)),
       first_fold_est = ests[[1]], first_fold_v = vars[[1]],
       # Previously these two entries repeated the first fold's values.
       second_fold_est = if (has_second) ests[[2]] else NA_real_,
       second_fold_v = if (has_second) vars[[2]] else NA_real_)
}






# Baseline estimator: regress y on x, then contrast the mean fitted value in
# the top-q tail with the mean fitted value in the bottom-q tail.
#
# Args:
#   y: Outcome.
#   x: Regressor(s) passed to lm(y ~ x).
#   q: Tail fraction; the top group lies above the (1 - q) quantile of the
#      fitted values and the bottom group below the q quantile.
#
# Returns: mean(top fitted values) - mean(bottom fitted values).
est_max <- function(y, x,q){
  mod <- lm(y ~ x)
  preds <- predict(mod)

  top_cut <- quantile(preds, 1 - q)
  # The bottom group uses the q quantile. The earlier code reused 1 - q here,
  # which compared the top tail with everything below it rather than with the
  # bottom tail used by the rest of the simulation.
  bottom_cut <- quantile(preds, q)

  mean(preds[preds > top_cut]) - mean(preds[preds < bottom_cut])
}

# Run one Monte Carlo replication: simulate correlated treatments and an
# outcome, then compute the cross-fitted MCSE estimate and four baselines.
#
# Args:
#   n, k:       Number of observations and of treatment columns.
#   rho:        Common pairwise correlation between treatment columns.
#   sd_epsilon: Unused; kept so existing calls that pass it keep working.
#   sd_eta:     Standard deviation of the outcome noise.
#   sd_beta:    Standard deviation of the simulated coefficients.
#   q:          Tail fraction defining the top and bottom groups.
#   null:       If TRUE, all coefficients are zero.
#
# Returns: unnamed numeric vector of length 10, in this order: MCSE estimate,
#   true-max, empirical-max, PCA and row-means baselines, MCSE standard error,
#   first-fold estimate, first-fold variance, second-fold estimate,
#   second-fold variance.
sim <- function(n, k, rho=.5, sd_epsilon=1, sd_eta=1, sd_beta=1, q=.3, null=FALSE){
  # The earlier n-by-k `epsilon` draw was never used and has been removed; this
  # shifts the random-number stream relative to the old code.
  eta <- rnorm(n, 0, sd_eta)
  beta <- if (null) rep(0, k) else rnorm(k, 0, sd = sd_beta)

  # Equicorrelated treatments with unit variances.
  .Sigma <- matrix(rho, ncol = k, nrow = k)
  diag(.Sigma) <- 1
  treat <- mvrnorm(n = n, mu = rep(0, k), Sigma = .Sigma)

  e_y <- treat %*% beta
  if (!null) {
    # Rescale so the top-q minus bottom-q gap in the expected outcome is 1.
    upper <- mean(e_y[e_y >= quantile(e_y, 1 - q)])
    lower <- mean(e_y[e_y <= quantile(e_y, q)])
    e_y <- e_y / (upper - lower)
  }
  y <- e_y + eta

  mcseq <- est_mcseq(treat_types = treat, y = y, nfold = 2, nboot = 100,
                     nsims = 100, est_fun = get_mcseq_ls_treat, q = q)

  # Baselines: first principal component, the treatment with the largest true
  # coefficient, the treatment with the largest estimated |coefficient|, and
  # the row mean of all treatments.
  pc <- prcomp(treat, retx = TRUE)
  pca_est <- est_max(y, pc$x[, 1], q = q)
  # NOTE(review): this picks the largest signed beta, while the empirical
  # version below uses abs(); confirm the asymmetry is intended.
  true_max <- est_max(y, treat[, which.max(beta)], q = q)
  beta_hat <- coef(lm(y ~ treat))
  emp_max <- est_max(y, treat[, which.max(abs(beta_hat[-1]))], q = q)
  means_est <- est_max(y = y, x = rowMeans(treat), q = q)

  c(mcseq$est, true_max, emp_max, pca_est, means_est, mcseq$se,
    mcseq$first_fold_est,
    mcseq$first_fold_v, mcseq$second_fold_est, mcseq$second_fold_v)
}

# Simulation grid: every combination of sample size, number of treatments,
# correlation, replication index and null / non-null design.
pars <- expand.grid(n=seq(200,1000, by=100), k=c(2,10,50), rho=c(0,.5,.9), niter=1:100, null=c(TRUE,FALSE))
ests <- vector(mode='list', length=nrow(pars))
for(i in seq_len(nrow(pars))){
  ests[[i]] <- sim(n=pars$n[i], k=pars$k[i], rho=pars$rho[i], sd_beta=1, null=pars$null[i])
}

# One row per grid point: the grid columns followed by the 10 values returned
# by sim(), in sim()'s return order.
ests <- cbind(pars, do.call(rbind, ests))
ests <- data.frame(ests)
# The last column holds the second fold's variance, so it is labelled
# 'second_v' to match 'first_v' (it was previously mislabelled 'second_se').
colnames(ests) <- c('n', 'k', 'rho', 'iter','null',
                    'MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means', 'se', 'first_est', 'first_v', 'second_est', 'second_v')

# Turn k and rho into labelled factors for faceting; k levels are set
# explicitly so the facets are ordered numerically.
ests$k <- paste0('k=', ests$k)
ests$k <- factor(ests$k, levels=c('k=2', 'k=10', 'k=50'))
ests$rho <- factor(paste0('rho=', ests$rho))

#saveRDS(ests, '~/Dropbox (MIT)/Diss/hdci/Intermediate Files/synth_ests.rds')

library(dplyr)
library(tidyverse)

# Root-mean-squared error of the estimates in `x` around a target value of 1.
get_rmse <- function(x){
  squared_err <- (1 - x)^2
  sqrt(mean(squared_err))
}


# Mean estimate per (n, k, null) cell, pooled over rho and iterations.
# Columns are selected by name so only the estimator columns are summarised.
by_k_bias <- ests[, c('n', 'k', 'null', 'MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means')] %>%
  group_by(n, k, null) %>%
  summarize_all(mean) %>%
  ungroup() %>%
  pivot_longer(cols=c('MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means')) %>%
  filter(!(name %in% c('True Max', 'Empirical Max', 'Means')))


# Monte Carlo standard error of each cell mean: sd / sqrt(rows in the cell).
# The cell size is computed rather than hard-coded (it is 300 for the full grid).
by_k_bias_se <- ests[, c('n', 'k', 'null', 'MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means')] %>%
  group_by(n, k, null) %>%
  summarize_all(function(x) sd(x) / sqrt(length(x))) %>%
  ungroup() %>%
  pivot_longer(cols=c('MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means')) %>%
  filter(!(name %in% c('True Max', 'Empirical Max', 'Means')))
max(by_k_bias_se$value)
min(by_k_bias_se$value)

# Bias relative to a target of 1; only the non-null rows are plotted below.
# by_k_bias and by_k_bias_se are built identically, so their rows line up.
by_k_bias$bias <- by_k_bias$value - 1
by_k_bias$lower <-  by_k_bias$bias - 1.96*by_k_bias_se$value
by_k_bias$upper <-  by_k_bias$bias + 1.96*by_k_bias_se$value
by_k_bias$Model <- by_k_bias$name

pdf('comp_sim_k_bias.pdf', width=8, height=3)
# Explicit print(): ggplot objects are only drawn automatically at an
# interactive/Rscript top level, not when the file is run via source().
print(
  ggplot(by_k_bias[!by_k_bias$null,], aes(x=n, y=bias, color=Model)) +
    geom_point() + geom_line() + facet_wrap(~k) + theme_few() +
    xlab('n') + ylab('Bias')
)
dev.off()

# Mean estimate per (n, rho, null) cell, pooled over k and iterations.
# Columns are selected by name so the factor column k is not passed to mean(),
# which would only produce a warning and an NA column.
by_rho_bias <- ests[, c('n', 'rho', 'null', 'MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means')] %>%
  group_by(n, rho, null) %>%
  summarize_all(mean) %>%
  ungroup() %>%
  pivot_longer(cols=c('MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means'))%>%
  filter(!(name %in% c('True Max', 'Empirical Max', "Means")))

# Bias relative to a target of 1; only the non-null rows are plotted below.
by_rho_bias$bias <- by_rho_bias$value - 1
by_rho_bias$Model <- by_rho_bias$name

pdf('comp_sim_rho_bias.pdf', width=8, height=3)
# Explicit print(): ggplot objects are only drawn automatically at an
# interactive/Rscript top level, not when the file is run via source().
print(
  ggplot(by_rho_bias[by_rho_bias$rho %in% c('rho=0', 'rho=0.5', 'rho=0.9') & !by_rho_bias$null,],
         aes(x=n, y=bias, color=Model)) +
    geom_point() + geom_line() + facet_wrap(~rho) + theme_few() +
    ylab('Bias')
)
dev.off()

# RMSE around 1 per (n, k, null) cell, pooled over rho and iterations.
# Columns are selected by name so the factor column rho is not passed to
# get_rmse(), which would only produce a warning and an NA column.
by_k_rmse <- ests[, c('n', 'k', 'null', 'MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means')] %>%
  group_by(n, k, null) %>%
  summarize_all(get_rmse) %>%
  ungroup() %>%
  pivot_longer(cols=c('MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means'))%>%
  filter(!(name %in% c('True Max', 'Empirical Max', 'Means')))


by_k_rmse$rmse <- by_k_rmse$value
by_k_rmse$Model <- by_k_rmse$name

pdf('comp_sim_k_rmse.pdf', width=8, height=3)
# Explicit print(): ggplot objects are only drawn automatically at an
# interactive/Rscript top level, not when the file is run via source().
print(
  ggplot(by_k_rmse[!by_k_rmse$null,], aes(x=n, y=rmse, color=Model)) +
    geom_point() + geom_line() + facet_wrap(~k) + theme_few() +
    xlab('n') + ylab('RMSE')+ ylim(0, .8)
)
dev.off()


# RMSE around 1 per (n, rho, null) cell, pooled over k and iterations.
# Columns are selected by name so the factor column k is not passed to
# get_rmse(), which would only produce a warning and an NA column.
by_rho_rmse <- ests[, c('n', 'rho', 'null', 'MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means')] %>%
  group_by(n, rho, null) %>%
  summarize_all(get_rmse) %>%
  ungroup() %>%
  pivot_longer(cols=c('MCSE', 'True Max', 'Empirical Max', 'PCA', 'Means'))%>%
  filter(!(name %in% c('True Max', 'Empirical Max', 'Means')))


by_rho_rmse$rmse <- by_rho_rmse$value
by_rho_rmse$Model <- by_rho_rmse$name

pdf('comp_sim_rho_rmse.pdf', width=8, height=3)
# Explicit print(): ggplot objects are only drawn automatically at an
# interactive/Rscript top level, not when the file is run via source().
print(
  ggplot(by_rho_rmse[!by_rho_rmse$null,], aes(x=n, y=rmse, color=Model)) +
    geom_point() + geom_line() + facet_wrap(~rho) + theme_few() +
    xlab('n') + ylab('RMSE') + ylim(0, .8)
)
dev.off()


# False-positive rate under the null design, per (n, k, rho) cell: the share
# of replications whose lower bound MCSE - 1.96 * se is not below zero.
coverage_rate <- group_by(ests[ests$null,], n, k, rho) %>%
  summarize(false_pos = 1 - mean(MCSE - 1.96*se < 0)) %>%
  ungroup()


pdf('false_positive_rates.pdf', width=8, height=3)
# Explicit print(): ggplot objects are only drawn automatically at an
# interactive/Rscript top level, not when the file is run via source().
print(
  ggplot(coverage_rate, aes(x=n, y=false_pos)) + geom_point()  + facet_wrap(~k) + theme_few() +
    xlab('n') + ylab('False Positive Rate') #+ ylim(0, .8)
)
dev.off()

